package wit_types

import (
	"runtime"
	"unsafe"
	"wit_component/wit_async"
	"wit_component/wit_runtime"
)

// StreamVtable bundles the per-element-type operations that StreamReader and
// StreamWriter use to move values of type T through a stream handle.
//
// Size is the byte stride between consecutive elements in a transfer buffer
// (used as the per-element offset), and Align is the alignment requested when
// such a buffer is allocated.
//
// Lift and Lower may be nil. With a nil Lift, StreamReader.Read lets the host
// write straight into the caller's []T; with a nil Lower, StreamWriter.Write
// hands the caller's []T to the host unchanged.
// NOTE(review): that direct path presumably relies on T's Go memory layout
// matching the element layout the host expects — confirm with the generator.
type StreamVtable[T any] struct {
	Size, Align uint32

	// Read and Write start a host stream operation on handle over length
	// elements located at items. Their uint32 result is passed straight to
	// wit_async.FutureOrStreamWait (presumably a status/return code).
	Read, Write func(handle int32, items unsafe.Pointer, length uint32) uint32

	// CancelRead and CancelWrite are not called by the code in this file;
	// presumably they cancel an in-flight operation on handle.
	CancelRead, CancelWrite func(handle int32) uint32

	// DropReadable and DropWritable release the readable and writable end of
	// a stream, respectively.
	DropReadable, DropWritable func(handle int32)

	// Lift decodes one element located at src. Lower encodes value into dst;
	// pinner is owned by the calling Write and is unpinned when Write returns.
	Lift  func(src unsafe.Pointer) T
	Lower func(pinner *runtime.Pinner, value T, dst unsafe.Pointer)
}

// StreamReader is the readable end of a stream carrying values of type T.
// Instances are built by MakeStreamReader.
type StreamReader[T any] struct {
	// vtable supplies the host calls and element conversions for T.
	vtable *StreamVtable[T]

	// handle wraps the raw readable-end handle value.
	handle *wit_runtime.Handle

	// writerDropped is set once a Read sees RETURN_CODE_DROPPED; later Reads
	// then return 0 without calling the host.
	writerDropped bool
}

// WriterDropped reports whether an earlier Read observed that the writable end
// of the stream was dropped. It only inspects local state; the host is not
// queried.
func (self *StreamReader[T]) WriterDropped() bool {
	dropped := self.writerDropped
	return dropped
}

// Read waits for up to len(dst) items from the stream, stores them at the front
// of dst, and returns how many were delivered.
//
// If an earlier call already saw the writer drop its end, Read returns 0
// without contacting the host. When the host reports RETURN_CODE_DROPPED,
// WriterDropped starts returning true. Any items delivered together with that
// code are still lifted into dst and counted.
//
// With a nil vtable.Lift, the host writes directly into dst's backing array.
// Otherwise a scratch buffer of vtable.Size bytes per element is allocated, and
// each received element is converted with Lift.
func (self *StreamReader[T]) Read(dst []T) uint32 {
	// NOTE(review): Use presumably validates that the handle is still owned
	// (it is called even when the writer is known to be gone) — confirm.
	h := self.handle.Use()
	if self.writerDropped {
		return 0
	}

	var pins runtime.Pinner
	defer pins.Unpin()

	vt := self.vtable
	capacity := uint32(len(dst))

	var buf unsafe.Pointer
	switch {
	case vt.Lift == nil:
		// No conversion step: the host fills dst in place.
		buf = unsafe.Pointer(unsafe.SliceData(dst))
	default:
		buf = wit_runtime.Allocate(&pins, uintptr(vt.Size*capacity), uintptr(vt.Align))
	}
	// The host holds a raw pointer to buf until the wait finishes; keep it pinned.
	pins.Pin(buf)

	status, got := wit_async.FutureOrStreamWait(vt.Read(h, buf, capacity), h)
	if status == wit_async.RETURN_CODE_DROPPED {
		self.writerDropped = true
	}

	if vt.Lift != nil {
		stride := int(vt.Size)
		for i := range int(got) {
			dst[i] = vt.Lift(unsafe.Add(buf, i*stride))
		}
	}

	return got
}

// Drop releases the readable end through vtable.DropReadable, provided the
// handle is still owned by this reader. A zero result from TakeOrNil (handle
// already taken or dropped) makes this a no-op.
func (self *StreamReader[T]) Drop() {
	if raw := self.handle.TakeOrNil(); raw != 0 {
		self.vtable.DropReadable(raw)
	}
}

// TakeHandle transfers the raw readable-end handle out of this reader and
// returns it. Afterwards the caller is responsible for the handle.
// NOTE(review): presumably Drop and the GC cleanup then see 0 from TakeOrNil
// and do nothing — confirm against wit_runtime.Handle.
func (self *StreamReader[T]) TakeHandle() int32 {
	raw := self.handle.Take()
	return raw
}

// MakeStreamReader wraps the raw readable-end handle handleValue in a
// StreamReader that uses vtable for its host calls and element conversions.
//
// A runtime cleanup is attached to the returned reader. If the reader becomes
// unreachable while it still owns the handle, the readable end is released via
// vtable.DropReadable. The cleanup closure captures only the handle wrapper and
// the vtable, never the reader itself, as runtime.AddCleanup requires.
func MakeStreamReader[T any](vtable *StreamVtable[T], handleValue int32) *StreamReader[T] {
	h := wit_runtime.MakeHandle(handleValue)
	reader := &StreamReader[T]{vtable: vtable, handle: h}
	runtime.AddCleanup(reader, func(_ int) {
		if raw := h.TakeOrNil(); raw != 0 {
			vtable.DropReadable(raw)
		}
	}, 0)
	return reader
}

// StreamWriter is the writable end of a stream carrying values of type T.
// Instances are built by MakeStreamWriter.
type StreamWriter[T any] struct {
	// vtable supplies the host calls and element conversions for T.
	vtable *StreamVtable[T]

	// handle wraps the raw writable-end handle value.
	handle *wit_runtime.Handle

	// readerDropped is set once a Write sees RETURN_CODE_DROPPED; later Writes
	// then return 0 without calling the host.
	readerDropped bool
}

// ReaderDropped reports whether an earlier Write observed that the readable end
// of the stream was dropped. It only inspects local state; the host is not
// queried.
func (self *StreamWriter[T]) ReaderDropped() bool {
	dropped := self.readerDropped
	return dropped
}

// Write offers items to the stream, waits for the host operation to finish,
// and returns how many leading items were accepted.
//
// If an earlier call already saw the reader drop its end, Write returns 0
// without contacting the host. When the host reports RETURN_CODE_DROPPED,
// ReaderDropped starts returning true.
//
// With a nil vtable.Lower, the host reads items' backing array directly.
// Otherwise every item is lowered into a freshly allocated buffer of
// vtable.Size bytes per element before the host call.
func (self *StreamWriter[T]) Write(items []T) uint32 {
	h := self.handle.Use()
	if self.readerDropped {
		return 0
	}

	var pins runtime.Pinner
	defer pins.Unpin()

	n := uint32(len(items))

	var buf unsafe.Pointer
	if lower := self.vtable.Lower; lower != nil {
		size := self.vtable.Size
		buf = wit_runtime.Allocate(&pins, uintptr(size*n), uintptr(self.vtable.Align))
		for i := range items {
			lower(&pins, items[i], unsafe.Add(buf, i*int(size)))
		}
	} else {
		// No conversion step: expose the caller's slice to the host, pinned
		// until this call returns.
		buf = unsafe.Pointer(unsafe.SliceData(items))
		pins.Pin(buf)
	}

	status, written := wit_async.FutureOrStreamWait(self.vtable.Write(h, buf, n), h)

	// TODO: restore handles to any unwritten resources, streams, or futures

	if status == wit_async.RETURN_CODE_DROPPED {
		self.readerDropped = true
	}

	return written
}

// WriteAll calls Write repeatedly until every item has been accepted or the
// reader is seen to have dropped its end. It returns the total number of items
// written, which is less than len(items) only in the dropped case.
//
// NOTE(review): this relies on every Write that does not report a dropped
// reader making progress. A 0-count result without RETURN_CODE_DROPPED would
// make the loop spin; confirm the host guarantees progress.
func (self *StreamWriter[T]) WriteAll(items []T) uint32 {
	total := uint32(len(items))
	var sent uint32
	for !self.readerDropped && sent < total {
		sent += self.Write(items[sent:])
	}
	return sent
}

// Drop releases the writable end through vtable.DropWritable, provided the
// handle is still owned by this writer. A zero result from TakeOrNil (handle
// already taken or dropped) makes this a no-op.
func (self *StreamWriter[T]) Drop() {
	if raw := self.handle.TakeOrNil(); raw != 0 {
		self.vtable.DropWritable(raw)
	}
}

// MakeStreamWriter wraps the raw writable-end handle handleValue in a
// StreamWriter that uses vtable for its host calls and element conversions.
//
// A runtime cleanup is attached to the returned writer. If the writer becomes
// unreachable while it still owns the handle, the writable end is released via
// vtable.DropWritable, which is the same intrinsic StreamWriter.Drop uses. The
// cleanup closure captures only the handle wrapper and the vtable, never the
// writer itself, as runtime.AddCleanup requires.
func MakeStreamWriter[T any](vtable *StreamVtable[T], handleValue int32) *StreamWriter[T] {
	handle := wit_runtime.MakeHandle(handleValue)
	value := &StreamWriter[T]{vtable, handle, false}
	runtime.AddCleanup(value, func(_ int) {
		handleValue := handle.TakeOrNil()
		if handleValue != 0 {
			// This wraps the writable end, so it must be released with
			// DropWritable. The previous DropReadable call targeted the wrong
			// end of the stream.
			vtable.DropWritable(handleValue)
		}
	}, 0)
	return value
}
